#! /usr/bin/env python3

# Cosinus distance is used in face recognition and person reidentification demo
from pathlib import Path
import torch
from torch import nn
import blobconverter
import onnx
from onnxsim import simplify

name = "cos_dist"

class Model(nn.Module):
    def forward(self, a, b):
        w12 =  torch.matmul(a, b)
        print('w12', w12)
        w1 = torch.matmul(a, a)
        w2 = torch.matmul(b, b)
        # Values w12,w1,w2 can be up to 10000. Multiplying
        # w1*w2 would mean Inf at FP16 (max value 65k),
        # so we divide both values by 1000
        # RESULT: Cos distance will be 0..1000.0
        mul = torch.multiply(torch.div(w1, 1000),torch.div(w2, 1000))
        print('mul',mul)
        # const = torch.tensor(1e-8) # To avoid division by 0
        n12 = torch.sqrt(mul)
        print('n12',n12)
        return torch.div(w12, n12)

# Define the expected input shape (dummy input)
shape = (256)
X = torch.ones(shape, dtype=torch.float32)

path = Path("out/")
path.mkdir(parents=True, exist_ok=True)
onnx_path = str(path / (name + '.onnx'))

print(f"Writing to {onnx_path}")
torch.onnx.export(
    Model(),
    (X, X),
    onnx_path,
    opset_version=12,
    do_constant_folding=True,
    input_names = ['a', 'b'], # Optional
    output_names = ['output'], # Optional
)

onnx_simplified_path = str(path / (name + '_simplified.onnx'))
# Use onnx-simplifier to simplify the onnx model
onnx_model = onnx.load(onnx_path)
model_simp, check = simplify(onnx_model)
onnx.save(model_simp, onnx_simplified_path)


# Use blobconverter to convert onnx->IR->blob
blobconverter.from_onnx(
    model=onnx_simplified_path,
    data_type="FP16",
    shaves=6,
    use_cache=False,
    output_dir="../models",
    optimizer_params=[],
    compile_params = [] # To avoid `-ip U8` by default
)

# Testing if this would work
# Cos dist [294.25] <= From VPU
# From CPU at FP32: 294.4874
# a = [-0.97509765625, -4.7265625, -0.6650390625, 6.84765625, -1.0419921875, 3.5625, -2.28125, 0.381103515625, 0.673828125, -0.96923828125, -3.283203125, 5.4609375, 1.0703125, 1.3125, 3.3671875, -0.442626953125, 1.5517578125, -1.189453125, 3.279296875, 0.2425537109375, -8.796875, -1.7314453125, 1.5947265625, -2.701171875, -1.2138671875, 0.2244873046875, -5.359375, 0.493408203125, -2.1328125, -0.404541015625, 1.78125, 5.86328125, 0.87255859375, 2.50390625, -1.3427734375, -2.46875, 4.421875, -0.408203125, -1.9814453125, 2.22265625, -4.45703125, 2.525390625, 2.0390625, -2.96484375, -0.353759765625, 7.42578125, 1.986328125, -4.2890625, -2.904296875, -2.998046875, 1.44140625, 5.56640625, 1.68359375, 1.4443359375, -2.763671875, 5.6484375, 3.419921875, -2.169921875, 1.7822265625, 0.27587890625, -4.81640625, 5.6953125, -0.247802734375, -3.150390625, -0.12335205078125, 0.31591796875, -2.021484375, 1.0888671875, 5.21484375, 0.61767578125, -3.650390625, 4.80078125, -1.6416015625, -0.86376953125, 0.59912109375, -5.453125, 1.3427734375, 0.6962890625, 3.107421875, 3.044921875, 3.875, 2.833984375, -2.0390625, 2.32421875, -0.80322265625, 3.33203125, -2.28125, -4.16015625, -1.94140625, -2.82421875, 2.55078125, 1.853515625, -4.3125, -2.16796875, -3.89453125, 1.6845703125, 2.13671875, -1.775390625, 3.333984375, -0.54248046875, -5.91796875, 1.2451171875, -4.1015625, -1.8828125, -5.06640625, -3.2578125, 1.7294921875, -3.2421875, -5.5703125, -0.395751953125, -1.4189453125, -1.978515625, -0.193359375, 0.364013671875, 4.7265625, -1.0556640625, 5.375, 3.904296875, 0.8955078125, -0.78955078125, 0.2303466796875, 1.7919921875, 1.7568359375, 4.16015625, 3.798828125, -0.9501953125, -0.6005859375, 6.86328125, -2.0703125, 2.89453125, 0.65771484375, -4.69140625, -3.31640625, -3.525390625, 0.73583984375, 2.9375, -1.248046875, -0.2275390625, -1.4814453125, 0.1097412109375, 1.98828125, -3.568359375, -0.88134765625, 1.0869140625, 0.08782958984375, -0.82080078125, 4.171875, -4.96875, -1.3681640625, -0.9521484375, 0.55908203125, -1.6513671875, -1.4482421875, -3.1796875, 0.8603515625, 2.744140625, 3.568359375, -1.0087890625, -1.91796875, -0.26220703125, -6.20703125, 2.3359375, -1.3759765625, -2.98828125, -1.380859375, -1.404296875, 6.26953125, 6.6328125, 3.236328125, -0.1300048828125, -3.291015625, 1.228515625, -5.90234375, -0.403564453125, -0.79150390625, 9.25, -1.2109375, 7.66015625, 1.4609375, 3.318359375, 5.88671875, 0.1644287109375, -1.60546875, 0.5654296875, 3.330078125, 3.94140625, 4.4921875, -3.03515625, 1.9599609375, 2.3203125, -2.849609375, -0.4072265625, 0.71826171875, 0.0096435546875, 2.32421875, 0.068359375, 4.84765625, -1.1904296875, -1.1298828125, -0.54345703125, 5.19921875, 0.419921875, -1.9599609375, -0.292724609375, -4.1328125, -0.63427734375, -7.41015625, 2.26171875, -2.236328125, 2.439453125, -4.58984375, 3.134765625, -0.91748046875, -3.8984375, -0.447509765625, 2.466796875, 2.306640625, -0.91650390625, -5.1796875, 3.8515625, -1.1943359375, -0.49951171875, 0.986328125, 2.375, 0.74462890625, 6.17578125, -2.216796875, -2.5078125, -2.251953125, 5.91015625, 4.06640625, -5.9296875, -2.86328125, -3.076171875, -0.0234375, 0.06976318359375, 0.74365234375, -1.1181640625, 0.2493896484375, 4.2890625, -0.94580078125, -2.380859375, 1.6875, 1.1123046875, 0.755859375, -1.51171875, -3.392578125, -1.4375, -2.1875, 3.369140625, 8.1328125, -1.8505859375, -4.38671875, -6.08984375, -0.53564453125, 5.80078125]
# b = [2.7265625, -4.71484375, -0.92138671875, 5.85546875, 2.1640625, 2.890625, 1.83984375, -3.478515625, -1.0, -2.392578125, -2.048828125, -1.828125, -5.6796875, -2.4453125, 4.57421875, -0.94287109375, -1.2373046875, -6.19140625, -0.264892578125, 2.865234375, -2.7421875, 3.9609375, -3.64453125, -3.37890625, -6.86328125, -1.4912109375, -5.80859375, 0.2451171875, -3.36328125, 0.60205078125, 4.8515625, 4.546875, 2.14453125, -4.37109375, 0.005126953125, -9.390625, 4.984375, -3.52734375, 1.822265625, 2.5, 5.609375, 6.25390625, 1.1748046875, -2.244140625, -3.880859375, -2.6484375, 4.00390625, 3.666015625, 5.74609375, -0.7470703125, -2.310546875, -1.2509765625, -0.5693359375, 4.1640625, -0.385986328125, 2.775390625, -2.234375, -2.396484375, -0.1285400390625, 3.77734375, -2.625, -2.2265625, 2.255859375, -0.869140625, 0.79052734375, -4.24609375, 0.053619384765625, -5.1015625, 1.9521484375, 2.37890625, -4.609375, -5.3046875, -0.2464599609375, -2.423828125, 5.41015625, -1.24609375, -4.3515625, 3.142578125, -0.74560546875, 2.322265625, 2.970703125, 3.994140625, 1.671875, 0.55810546875, 1.29296875, 0.363525390625, -0.759765625, -1.0009765625, -4.49609375, -3.25390625, -0.55419921875, 2.810546875, -0.92529296875, 0.395751953125, -2.125, -1.9169921875, -7.21875, -3.310546875, 3.033203125, 0.53564453125, -3.78515625, -3.658203125, -2.865234375, -4.96484375, -1.6787109375, 0.81298828125, -6.0546875, -4.2109375, -2.75390625, 0.62158203125, 0.39892578125, -3.24609375, 3.599609375, -1.455078125, 6.1953125, 4.34765625, -3.00390625, 2.623046875, -0.1405029296875, -4.47265625, 1.0361328125, -5.09375, 1.08984375, -0.258056640625, -3.66796875, 0.46826171875, 1.5859375, 1.470703125, -2.166015625, -1.1650390625, 1.224609375, -2.17578125, -4.4140625, -1.50390625, 3.826171875, -3.099609375, -1.8447265625, -0.241943359375, -2.9921875, 2.865234375, -1.3544921875, -3.298828125, -0.27734375, 0.461669921875, -0.26806640625, 4.45703125, 1.302734375, -7.08984375, 1.6337890625, 4.58984375, -0.198486328125, 0.46337890625, -0.266845703125, -4.640625, 2.552734375, 3.32421875, -1.5595703125, -5.81640625, 0.60595703125, -0.90478515625, 0.198974609375, 2.982421875, 1.8056640625, 1.0947265625, 2.95703125, 3.3125, 0.0240478515625, 2.384765625, 0.8720703125, -1.205078125, 3.751953125, -0.9501953125, -9.7734375, 3.466796875, 1.9775390625, 2.962890625, 6.2734375, 3.357421875, -0.1767578125, 5.55078125, -0.037994384765625, -2.365234375, -0.24267578125, 3.59765625, 0.2484130859375, 2.103515625, 0.93310546875, -0.6953125, -1.244140625, -2.0078125, -1.9638671875, -0.71337890625, -2.982421875, 1.8896484375, 5.66015625, -1.4208984375, 2.45703125, -1.6865234375, -1.462890625, 0.67919921875, -0.4990234375, 3.861328125, -0.7314453125, -5.00390625, -3.2734375, 0.1248779296875, 0.48974609375, 1.0693359375, -2.87109375, 1.767578125, -3.103515625, 3.359375, 3.36328125, -1.8720703125, 0.62890625, -1.2490234375, -3.1015625, 1.3974609375, 1.1357421875, -4.86328125, 4.28125, 4.45703125, 0.34619140625, 0.212646484375, 0.4609375, -3.07421875, -2.75, -5.01171875, -2.48046875, 1.419921875, -0.74560546875, -2.34765625, 1.578125, -1.9267578125, -0.257568359375, -3.28515625, -1.3203125, 0.2193603515625, 0.8857421875, -0.94921875, -1.4130859375, -0.61376953125, 4.87890625, 0.52392578125, 1.8203125, -4.23046875, 2.73046875, -3.5390625, 1.716796875, -0.9853515625, 4.75, 3.76171875, -2.208984375, -6.75, 2.46875, 5.5]
# import numpy as np
# test = Model()
# npa = torch.from_numpy(np.array(a, dtype=np.float32))
# npb = torch.from_numpy(np.array(b, dtype=np.float32))

# print(test.forward(npa, npb))

# # function we use on the host to check cosinus distance
# def cos_dist(a, b):
#     return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

# print('Numpy dist', cos_dist(npa, npb))